/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "common/format/axis_name_util.h"
#include "common/format/axis_util.h"

namespace fe {
// Per-format callbacks used by GetReshapeType: each one appends the names of the axes that
// are NOT listed in the given axis values. A format without an entry here makes
// GetReshapeType return an empty reshape type.
const std::map<ge::Format, GetAxisNameByAxisValueInfoPtr> AxisNameUtil::get_axis_name_except_func_map = {
    {ge::FORMAT_CHWN, std::make_shared<GetAxisNameByAxisValueInfo>(GetCHWNAxisExceptName)},
    {ge::FORMAT_HWCN, std::make_shared<GetAxisNameByAxisValueInfo>(GetHWCNAxisExceptName)},
    {ge::FORMAT_NCHW, std::make_shared<GetAxisNameByAxisValueInfo>(GetNCHWAxisExceptName)},
    {ge::FORMAT_NHWC, std::make_shared<GetAxisNameByAxisValueInfo>(GetNHWCAxisExceptName)},
};

/** Concatenate the axis names, in order and without any separator.
 *  @param axis_name names to join (not modified)
 *  @return the concatenation; an empty string when axis_name is empty */
std::string AxisNameUtil::AxisNameToStr(std::vector<std::string> &axis_name) {
  std::string joined;
  for (const std::string &name : axis_name) {
    joined += name;
  }
  return joined;
}

/** Get the reshape type of a reduce op from its format and reduce axis values.
 *  1. Collect the names of the axes that are NOT reduced, using the callback registered for
 *     `format` in get_axis_name_except_func_map. Example: format NCHW with axis values [0, 1]
 *     leaves axis name HW.
 *  2. Concatenate those names (AxisNameToStr) to form the reshape type.
 *  An empty string (the default reshape type) is returned when axis_values is empty, the format
 *  has no registered callback, the callback is null, or no axis name remains. */
std::string AxisNameUtil::GetReshapeType(const ge::Format &format, std::vector<int64_t> &axis_values) {
  if (axis_values.empty()) {
    FE_LOGD("axis value is empty, return default reshape type.");
    return std::string();
  }

  const auto func_iter = get_axis_name_except_func_map.find(format);
  if (func_iter == get_axis_name_except_func_map.end()) {
    FE_LOGW("Can not get axis name of old format %u!", format);
    return std::string();
  }

  const GetAxisNameByAxisValueInfoPtr &name_getter = func_iter->second;
  if (name_getter == nullptr) {
    return std::string();
  }

  // Names of the axes that survive the reduction.
  std::vector<std::string> kept_axis_names;
  (void)(*name_getter)(axis_values, kept_axis_names);
  if (kept_axis_names.empty()) {
    FE_LOGD("axis name is empty, return default reshape type.");
    return std::string();
  }
  return AxisNameToStr(kept_axis_names);
}

/** Get the axis indices that are NOT reduced.
 *  For example, a reduce op with format NCHW (axis_nums = 4) and axis values [1, 2]
 *  yields [0, 3]. Negative axis values count from the end (-1 is the last axis), so axis
 *  values [-1] with axis_nums = 4 yield [0, 1, 2]. Values outside [-axis_nums, axis_nums)
 *  match no axis and are ignored.
 *  @param axis_values reduce axis values (read only)
 *  @param axis_nums number of axes of the format
 *  @return the remaining axis indices in ascending order */
std::vector<int64_t> AxisNameUtil::GetExceptAxisValue(vector<int64_t> &axis_values, const size_t &axis_nums) {
  const int64_t dim_num = static_cast<int64_t>(axis_nums);
  std::vector<int64_t> axis_except;
  for (int64_t dim = 0; dim < dim_num; ++dim) {
    // Compare in signed arithmetic: comparing an int64_t axis against a size_t index converts
    // negative axes to huge unsigned values, so e.g. -1 would never match the last axis.
    // `dim - dim_num` is the negative alias of `dim` and cannot overflow.
    const bool is_reduced = std::any_of(axis_values.begin(), axis_values.end(),
                                        [dim, dim_num](const int64_t axis) {
                                          return axis == dim || axis == dim - dim_num;
                                        });
    if (!is_reduced) {
      axis_except.emplace_back(dim);
    }
  }
  return axis_except;
}

/** Append to axis_name the names of the NCHW axes that are not listed in axis_values,
 *  in the ascending order returned by GetExceptAxisValue. Always returns SUCCESS. */
Status AxisNameUtil::GetNCHWAxisExceptName(std::vector<int64_t> &axis_values, std::vector<std::string> &axis_name) {
  const std::vector<int64_t> kept_axes = GetExceptAxisValue(axis_values, DIM_DEFAULT_SIZE);
  for (const int64_t axis : kept_axes) {
    if (axis == NCHW_DIM_N) {
      axis_name.emplace_back(N_AXIS_NAME);
    } else if (axis == NCHW_DIM_C) {
      axis_name.emplace_back(C_AXIS_NAME);
    } else if (axis == NCHW_DIM_H) {
      axis_name.emplace_back(H_AXIS_NAME);
    } else if (axis == NCHW_DIM_W) {
      axis_name.emplace_back(W_AXIS_NAME);
    }
  }
  return SUCCESS;
}

/** Append to axis_name the names of the NHWC axes that are not listed in axis_values,
 *  in the ascending order returned by GetExceptAxisValue. Always returns SUCCESS. */
Status AxisNameUtil::GetNHWCAxisExceptName(std::vector<int64_t> &axis_values, std::vector<std::string> &axis_name) {
  const std::vector<int64_t> kept_axes = GetExceptAxisValue(axis_values, DIM_DEFAULT_SIZE);
  for (const int64_t axis : kept_axes) {
    if (axis == NHWC_DIM_N) {
      axis_name.emplace_back(N_AXIS_NAME);
    } else if (axis == NHWC_DIM_H) {
      axis_name.emplace_back(H_AXIS_NAME);
    } else if (axis == NHWC_DIM_W) {
      axis_name.emplace_back(W_AXIS_NAME);
    } else if (axis == NHWC_DIM_C) {
      axis_name.emplace_back(C_AXIS_NAME);
    }
  }
  return SUCCESS;
}

/** Append to axis_name the names of the HWCN axes that are not listed in axis_values,
 *  in the ascending order returned by GetExceptAxisValue. Always returns SUCCESS. */
Status AxisNameUtil::GetHWCNAxisExceptName(std::vector<int64_t> &axis_values, std::vector<std::string> &axis_name) {
  const std::vector<int64_t> kept_axes = GetExceptAxisValue(axis_values, DIM_DEFAULT_SIZE);
  for (const int64_t axis : kept_axes) {
    if (axis == HWCN_DIM_H) {
      axis_name.emplace_back(H_AXIS_NAME);
    } else if (axis == HWCN_DIM_W) {
      axis_name.emplace_back(W_AXIS_NAME);
    } else if (axis == HWCN_DIM_C) {
      axis_name.emplace_back(C_AXIS_NAME);
    } else if (axis == HWCN_DIM_N) {
      axis_name.emplace_back(N_AXIS_NAME);
    }
  }
  return SUCCESS;
}

/** Append to axis_name the names of the CHWN axes that are not listed in axis_values,
 *  in the ascending order returned by GetExceptAxisValue. Always returns SUCCESS. */
Status AxisNameUtil::GetCHWNAxisExceptName(std::vector<int64_t> &axis_values, std::vector<std::string> &axis_name) {
  const std::vector<int64_t> kept_axes = GetExceptAxisValue(axis_values, DIM_DEFAULT_SIZE);
  for (const int64_t axis : kept_axes) {
    if (axis == CHWN_DIM_C) {
      axis_name.emplace_back(C_AXIS_NAME);
    } else if (axis == CHWN_DIM_H) {
      axis_name.emplace_back(H_AXIS_NAME);
    } else if (axis == CHWN_DIM_W) {
      axis_name.emplace_back(W_AXIS_NAME);
    } else if (axis == CHWN_DIM_N) {
      axis_name.emplace_back(N_AXIS_NAME);
    }
  }
  return SUCCESS;
}

/** Map a reduce op's axis attribute from its original format to its current format.
 *  Step 1 (GetOriginalAxisName) names the original axes for origin_format; step 2
 *  (GetNewAxisInfoByName) looks those names up for current_format and appends the resulting
 *  indices to axis_index_vec.
 *  @return SUCCESS, or FAILED if either step fails (an inner error is reported and a warning
 *          is logged naming the format that failed). */
Status AxisNameUtil::GetNewAxisAttributeValue(const ge::OpDesc &op_desc, const ge::Format &origin_format,
                                              const ge::Format &current_format, const ge::GeShape &origin_shape,
                                              std::vector<int64_t> &axis_index_vec) {
  // get old axis name
  std::vector<std::string> axis_names;
  if (GetOriginalAxisName(op_desc, origin_format, origin_shape, axis_names) != SUCCESS) {
    REPORT_INNER_ERROR(EM_INNER_ERROR,
                       "[GraphOpt][SetAxis][GetAxisName][Op %s,type=%s]:Get axis name for ori format %u failed!",
                       op_desc.GetName().c_str(), op_desc.GetType().c_str(), origin_format);
    FE_LOGW("[GraphOpt][SetAxis][GetAxisName][Op name=%s,type=%s]:Get axis name for ori format %u failed!",
            op_desc.GetName().c_str(), op_desc.GetType().c_str(), origin_format);
    return FAILED;
  }
  // get new axis info; the warning must name the CURRENT format, which is what is printed here
  if (GetNewAxisInfoByName(op_desc, current_format, origin_shape, axis_names, axis_index_vec) != SUCCESS) {
    REPORT_INNER_ERROR(EM_INNER_ERROR,
                       "[GraphOpt][SetAxis][GetAxisName][Op %s,type=%s]:Get axis name for current format %u failed!",
                       op_desc.GetName().c_str(), op_desc.GetType().c_str(), current_format);
    FE_LOGW("[GraphOpt][SetAxis][GetAxisName][Op name=%s,type=%s]:Get axis name for current format %u failed!",
            op_desc.GetName().c_str(), op_desc.GetType().c_str(), current_format);
    return FAILED;
  }
  return SUCCESS;
}

/** Translate axis names into axis indices for `format`.
 *  For each name in axis_name (in order), every index listed for that name under `format`
 *  in FORMAT_AXIS_NAME_NUMBER_MAP is appended to axis_index_vec. Names without an entry,
 *  or a format without an entry, contribute nothing. origin_shape is not used here.
 *  @return always SUCCESS */
Status AxisNameUtil::GetNewAxisInfoByName(const ge::OpDesc &op_desc, const ge::Format &format,
                                          const ge::GeShape &origin_shape, std::vector<std::string> &axis_name,
                                          std::vector<int64_t> &axis_index_vec) {
  // The format lookup does not depend on the axis name, so do it once; bind the inner map by
  // reference instead of copying the whole map for every axis name.
  const auto format_iter = FORMAT_AXIS_NAME_NUMBER_MAP.find(format);
  if (format_iter != FORMAT_AXIS_NAME_NUMBER_MAP.end()) {
    const auto &axis_name_number_map = format_iter->second;
    for (const auto &name : axis_name) {
      const auto number_iter = axis_name_number_map.find(name);
      if (number_iter == axis_name_number_map.end()) {
        continue;
      }
      for (const auto &axis_number : number_iter->second) {
        axis_index_vec.emplace_back(axis_number);
      }
    }
  }

  for (const auto &axis_index : axis_index_vec) {
    FE_LOGD("Get reduce op [%s] axis new value is [%ld].", op_desc.GetName().c_str(), axis_index);
  }
  return SUCCESS;
}

/** Name the reduce axes of op_desc in its original format.
 *  Axis indices come from AxisUtil::GetOriginAxisAttribute; the per-format Get*AxisName helper
 *  appends the matching names to axis_name_vec.
 *  @return FAILED if the axis attribute cannot be read or the format is not one of
 *          NCHW/NHWC/HWCN/CHWN/NDHWC/NCDHW/DHWCN; otherwise the helper's status. */
Status AxisNameUtil::GetOriginalAxisName(const ge::OpDesc &op_desc, const ge::Format &format,
                                         const ge::GeShape &origin_shape, std::vector<std::string> &axis_name_vec) {
  std::vector<int64_t> origin_axes;
  if (AxisUtil::GetOriginAxisAttribute(op_desc, origin_shape, origin_axes) != SUCCESS) {
    FE_LOGW("Get reduce op [%s] new axis info failed!", op_desc.GetName().c_str());
    return FAILED;
  }

  Status result = FAILED;
  switch (format) {
    case ge::FORMAT_NCHW:
      result = GetNCHWAxisName(origin_axes, axis_name_vec);
      break;
    case ge::FORMAT_NHWC:
      result = GetNHWCAxisName(origin_axes, axis_name_vec);
      break;
    case ge::FORMAT_HWCN:
      result = GetHWCNAxisName(origin_axes, axis_name_vec);
      break;
    case ge::FORMAT_CHWN:
      result = GetCHWNAxisName(origin_axes, axis_name_vec);
      break;
    case ge::FORMAT_NDHWC:
      result = GetNDHWCAxisName(origin_axes, axis_name_vec);
      break;
    case ge::FORMAT_NCDHW:
      result = GetNCDHWAxisName(origin_axes, axis_name_vec);
      break;
    case ge::FORMAT_DHWCN:
      result = GetDHWCNAxisName(origin_axes, axis_name_vec);
      break;
    default:
      // Unsupported original format: result stays FAILED.
      break;
  }

  for (const std::string &name : axis_name_vec) {
    FE_LOGD("Get reduce op [%s] axis name is [%s].", op_desc.GetName().c_str(), name.c_str());
  }
  return result;
}

/** Append to axis_name the NCHW axis name of each value in axis_values, in input order.
 *  Values that match no NCHW dimension are skipped. Always returns SUCCESS. */
Status AxisNameUtil::GetNCHWAxisName(std::vector<int64_t> &axis_values, std::vector<std::string> &axis_name) {
  for (const int64_t axis : axis_values) {
    if (axis == NCHW_DIM_N) {
      axis_name.emplace_back(N_AXIS_NAME);
    } else if (axis == NCHW_DIM_C) {
      axis_name.emplace_back(C_AXIS_NAME);
    } else if (axis == NCHW_DIM_H) {
      axis_name.emplace_back(H_AXIS_NAME);
    } else if (axis == NCHW_DIM_W) {
      axis_name.emplace_back(W_AXIS_NAME);
    }
  }
  return SUCCESS;
}

/** Append to axis_name the NHWC axis name of each value in axis_values, in input order.
 *  Values that match no NHWC dimension are skipped. Always returns SUCCESS. */
Status AxisNameUtil::GetNHWCAxisName(std::vector<int64_t> &axis_values, std::vector<std::string> &axis_name) {
  for (const int64_t axis : axis_values) {
    if (axis == NHWC_DIM_N) {
      axis_name.emplace_back(N_AXIS_NAME);
    } else if (axis == NHWC_DIM_H) {
      axis_name.emplace_back(H_AXIS_NAME);
    } else if (axis == NHWC_DIM_W) {
      axis_name.emplace_back(W_AXIS_NAME);
    } else if (axis == NHWC_DIM_C) {
      axis_name.emplace_back(C_AXIS_NAME);
    }
  }
  return SUCCESS;
}

/** Append to axis_name the HWCN axis name of each value in axis_values, in input order.
 *  Values that match no HWCN dimension are skipped. Always returns SUCCESS. */
Status AxisNameUtil::GetHWCNAxisName(std::vector<int64_t> &axis_values, std::vector<std::string> &axis_name) {
  for (const int64_t axis : axis_values) {
    if (axis == HWCN_DIM_H) {
      axis_name.emplace_back(H_AXIS_NAME);
    } else if (axis == HWCN_DIM_W) {
      axis_name.emplace_back(W_AXIS_NAME);
    } else if (axis == HWCN_DIM_C) {
      axis_name.emplace_back(C_AXIS_NAME);
    } else if (axis == HWCN_DIM_N) {
      axis_name.emplace_back(N_AXIS_NAME);
    }
  }
  return SUCCESS;
}

/** Append to axis_name the CHWN axis name of each value in axis_values, in input order.
 *  Values that match no CHWN dimension are skipped. Always returns SUCCESS. */
Status AxisNameUtil::GetCHWNAxisName(std::vector<int64_t> &axis_values, std::vector<std::string> &axis_name) {
  for (const int64_t axis : axis_values) {
    if (axis == CHWN_DIM_C) {
      axis_name.emplace_back(C_AXIS_NAME);
    } else if (axis == CHWN_DIM_H) {
      axis_name.emplace_back(H_AXIS_NAME);
    } else if (axis == CHWN_DIM_W) {
      axis_name.emplace_back(W_AXIS_NAME);
    } else if (axis == CHWN_DIM_N) {
      axis_name.emplace_back(N_AXIS_NAME);
    }
  }
  return SUCCESS;
}

/** Append to axis_name the NDHWC axis name of each value in axis_values, in input order.
 *  Values that match no NDHWC dimension are skipped. Always returns SUCCESS. */
Status AxisNameUtil::GetNDHWCAxisName(std::vector<int64_t> &axis_values, std::vector<std::string> &axis_name) {
  for (const int64_t axis : axis_values) {
    if (axis == NDHWC_DIM_C) {
      axis_name.emplace_back(C_AXIS_NAME);
    } else if (axis == NDHWC_DIM_H) {
      axis_name.emplace_back(H_AXIS_NAME);
    } else if (axis == NDHWC_DIM_W) {
      axis_name.emplace_back(W_AXIS_NAME);
    } else if (axis == NDHWC_DIM_N) {
      axis_name.emplace_back(N_AXIS_NAME);
    } else if (axis == NDHWC_DIM_D) {
      axis_name.emplace_back(D_AXIS_NAME);
    }
  }
  return SUCCESS;
}

/** Append to axis_name the NCDHW axis name of each value in axis_values, in input order.
 *  Values that match no NCDHW dimension are skipped. Always returns SUCCESS. */
Status AxisNameUtil::GetNCDHWAxisName(std::vector<int64_t> &axis_values, std::vector<std::string> &axis_name) {
  // Range-for instead of a uint32_t index: the index type no longer has to match
  // vector::size() (size_t), which the sibling helpers already assume.
  for (const int64_t axis_value_temp : axis_values) {
    if (axis_value_temp == NCDHW_DIM_C) {
      axis_name.emplace_back(C_AXIS_NAME);
    } else if (axis_value_temp == NCDHW_DIM_H) {
      axis_name.emplace_back(H_AXIS_NAME);
    } else if (axis_value_temp == NCDHW_DIM_W) {
      axis_name.emplace_back(W_AXIS_NAME);
    } else if (axis_value_temp == NCDHW_DIM_N) {
      axis_name.emplace_back(N_AXIS_NAME);
    } else if (axis_value_temp == NCDHW_DIM_D) {
      axis_name.emplace_back(D_AXIS_NAME);
    }
  }
  return SUCCESS;
}

/** Append to axis_name the DHWCN axis name of each value in axis_values, in input order.
 *  Values that match no DHWCN dimension are skipped. Always returns SUCCESS. */
Status AxisNameUtil::GetDHWCNAxisName(std::vector<int64_t> &axis_values, std::vector<std::string> &axis_name) {
  for (const int64_t axis : axis_values) {
    if (axis == DHWCN_DIM_C) {
      axis_name.emplace_back(C_AXIS_NAME);
    } else if (axis == DHWCN_DIM_H) {
      axis_name.emplace_back(H_AXIS_NAME);
    } else if (axis == DHWCN_DIM_W) {
      axis_name.emplace_back(W_AXIS_NAME);
    } else if (axis == DHWCN_DIM_N) {
      axis_name.emplace_back(N_AXIS_NAME);
    } else if (axis == DHWCN_DIM_D) {
      axis_name.emplace_back(D_AXIS_NAME);
    }
  }
  return SUCCESS;
}
};  // namespace fe
